package com.njcb.ams.factory.domain;

import com.njcb.ams.support.exception.ExceptionCode;
import com.njcb.ams.support.exception.ExceptionUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.config.AutowireCapableBeanFactory;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableBeanFactory;
import org.springframework.beans.factory.support.AbstractAutowireCapableBeanFactory;
import org.springframework.beans.factory.support.AbstractBeanDefinition;
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.stereotype.Component;

import java.util.Arrays;

/**
 * Static access point to the Spring {@link ApplicationContext}, with helpers to look up beans
 * and to register (or replace) a bean definition at runtime.
 *
 * <p>The context lives in a static field and must be supplied through {@link #setContext}
 * before the lookup helpers can find anything. NOTE(review): the caller of {@code setContext}
 * is not visible in this file — presumably an application bootstrap hook; confirm.
 *
 * @author liuyanlong
 */
@Component
public class AppContext {
    private static final Logger logger = LoggerFactory.getLogger(AppContext.class);

    private static ApplicationContext context;

    /**
     * Returns the application context registered via {@link #setContext}.
     *
     * @return the current context, or {@code null} if none has been set yet
     */
    public static ApplicationContext getContext() {
        return context;
    }

    /**
     * Looks up the bean registered under the given id.
     *
     * @param beanId bean id/name
     * @return the bean instance, or {@code null} if no context has been set yet
     * @throws org.springframework.beans.BeansException propagated from the context when no bean
     *     with this id exists or it cannot be created
     * @throws RuntimeException if the context returns {@code null} for the id
     */
    public static Object getBean(String beanId) {
        if (null == context) {
            return null;
        }
        logger.trace("beanId : {}", beanId);
        Object bean = context.getBean(beanId);
        // Defensive only: Spring normally throws NoSuchBeanDefinitionException rather than
        // returning null for an unknown id.
        if (bean == null) {
            logger.error("bean id = {} not found", beanId);
            throw new RuntimeException("bean id = " + beanId + " not found");
        }
        return bean;
    }

    /**
     * Looks up a bean by the name derived from its class's simple name (see
     * {@link #defaultBeanName(Class)}), e.g. {@code UserService} -> {@code "userService"} and
     * {@code URLParser} -> {@code "URLParser"}.
     *
     * @param clazz bean class whose simple name determines the bean id
     * @param <T> bean type
     * @return the bean instance, or {@code null} if no context has been set yet
     * @throws IllegalArgumentException if {@code clazz} has no simple name (anonymous class)
     */
    @SuppressWarnings("unchecked")
    public static <T> T getBean(Class<T> clazz) {
        // Unchecked cast kept deliberately: a type mismatch surfaces at the caller, as before.
        return (T) getBean(defaultBeanName(clazz));
    }

    /**
     * Derives a bean name from the class's simple name using the JavaBeans "decapitalize" rule
     * that Spring's default bean naming follows: the first character is lower-cased, unless the
     * first two characters are both upper-case, in which case the name is returned unchanged.
     * NOTE(review): for nested classes Spring's default name includes the outer class
     * ({@code outer.Inner}); this helper only uses the simple name, as the original code did.
     */
    private static String defaultBeanName(Class<?> clazz) {
        String simpleName = clazz.getSimpleName();
        if (simpleName.isEmpty()) {
            throw new IllegalArgumentException(
                    "Cannot derive a bean name from anonymous class " + clazz.getName());
        }
        if (simpleName.length() > 1
                && Character.isUpperCase(simpleName.charAt(0))
                && Character.isUpperCase(simpleName.charAt(1))) {
            return simpleName;
        }
        return Character.toLowerCase(simpleName.charAt(0)) + simpleName.substring(1);
    }

    /**
     * Stores the application context used by all static helpers of this class.
     *
     * @param ctx the application context (may be {@code null} to clear it)
     */
    public static void setContext(ApplicationContext ctx) {
        context = ctx;
    }

    /**
     * Checks whether the context contains a bean with the given id.
     *
     * @param beanId bean id/name
     * @return {@code true} if a bean with this id exists; {@code false} if it does not or if no
     *     context has been set yet
     */
    public static boolean containBean(String beanId) {
        return context != null && context.containsBean(beanId);
    }

    /**
     * Registers a new singleton bean definition for {@code beanClass} under {@code beanName}
     * and returns its instance. An existing bean definition with the same name is removed first.
     *
     * <p>The definition uses the class's no-arg constructor (so a fresh instance is created),
     * autowires bean properties by type, and is marked primary.
     *
     * @param beanName name to register the bean under
     * @param <T> bean type
     * @param beanClass class to instantiate
     * @return the bean instance registered under {@code beanName}
     * @throws IllegalStateException if no context has been set yet
     */
    public static <T> T injectBean(String beanName, Class<T> beanClass) {
        if (context == null) {
            throw new IllegalStateException(
                    "ApplicationContext has not been set; cannot inject bean " + beanName);
        }
        DefaultListableBeanFactory beanFactory = requireListableBeanFactory();

        // containsBean() is also true for manually registered singletons, aliases and beans of a
        // parent factory; removeBeanDefinition() would throw for those. Only remove a bean
        // definition that actually exists in this factory.
        if (beanFactory.containsBeanDefinition(beanName)) {
            beanFactory.removeBeanDefinition(beanName);
        }

        // Uses the class's no-arg constructor by default; a new instance is created here.
        BeanDefinitionBuilder beanDefinitionBuilder = BeanDefinitionBuilder.genericBeanDefinition(beanClass);
        AbstractBeanDefinition beanDefinition = beanDefinitionBuilder.getBeanDefinition();
        // Autowire bean properties by type.
        beanDefinition.setAutowireMode(AutowireCapableBeanFactory.AUTOWIRE_BY_TYPE);
        beanDefinition.setPrimary(true);
        beanDefinition.setSynthetic(true);
        // Application-level role.
        beanDefinition.setRole(BeanDefinition.ROLE_APPLICATION);
        beanDefinition.setScope(ConfigurableBeanFactory.SCOPE_SINGLETON);

        beanFactory.registerBeanDefinition(beanName, beanDefinition);
        // Resolve by name as well as type so the bean just registered is returned even when
        // other beans of the same type (possibly also primary) exist.
        return context.getBean(beanName, beanClass);
    }

    /**
     * Returns the context's autowire-capable bean factory as a {@link DefaultListableBeanFactory},
     * which is required to add/remove bean definitions at runtime.
     */
    private static DefaultListableBeanFactory requireListableBeanFactory() {
        AutowireCapableBeanFactory factory = context.getAutowireCapableBeanFactory();
        if (!(factory instanceof DefaultListableBeanFactory)) {
            ExceptionUtil.throwAppException("暂未支持此功能", ExceptionCode.DEFAULT_EXCEPTION);
            // NOTE(review): throwAppException is expected to always throw; this guards against it
            // returning, which previously led to a NullPointerException.
            throw new UnsupportedOperationException(
                    "AutowireCapableBeanFactory is not a DefaultListableBeanFactory");
        }
        return (DefaultListableBeanFactory) factory;
    }

}
